//+------------------------------------------------------------------+
//|                                                     LDA Test.mq5 |
//|                                     Copyright 2023, Omega Joctan |
//|                        https://www.mql5.com/en/users/omegajoctan |
//+------------------------------------------------------------------+
#property copyright "Copyright 2023, Omega Joctan"
#property link      "https://www.mql5.com/en/users/omegajoctan"
#property version   "1.00"

//--- Machine learning libraries

#include <LDA.mqh>
#include <PCA.mqh>
#include <tree.mqh>
#include <Metrics.mqh>

//--- Model objects. They are allocated with `new` inside TrainTest() and
//--- deleted in OnDeinit(). Only one of lda/pca is created per run, depending
//--- on the dimension_reduction input.
CLDA *lda;
CPCA *pca;
CDecisionTreeClassifier *classifier_tree;

//--- Trading libraries
//--- MAGICNUMBER is set on the trade object in OnInit() and is the value
//--- PosExists() is called with to recognise positions opened by this expert.
#define MAGICNUMBER 10022024

#include <Trade\Trade.mqh>
#include <Trade\PositionInfo.mqh>

//--- m_trade sends the Buy/Sell requests in OnTick()
CTrade m_trade;
//--- m_position is used by PosExists() to inspect the open positions
CPositionInfo m_position;

// Dimensionality-reduction technique applied to the indicator data before it
// reaches the decision tree. Used as the type of the dimension_reduction input.
// Keep comments off the enumerator lines themselves: a same-line comment on an
// enumerator is displayed as its label in the terminal's inputs dialog.
enum dimension_reduction_enum
  {
    LDA = 0,
    PCA = 1
  };
  
//--- Model settings.
//--- NOTE: comments are deliberately placed on their own lines; a comment on
//--- the same line as an input would be shown as that input's display name.
input group "Machine Learning"
//--- Number of bars requested from every indicator and from the price history in TrainTest()
input int bars = 1000;
//--- Forwarded to MatrixExtend::TrainTestSplitMatrices (presumably the shuffle seed - confirm in the library)
input int random_state_ = 42;
//--- Forwarded to MatrixExtend::TrainTestSplitMatrices (presumably the training share of the rows - confirm in the library)
input double training_sample_size = 0.7;
//--- Selects which reducer is fitted in TrainTest() and applied to new data in OnTick()
input dimension_reduction_enum dimension_reduction = LDA;

//--- Order settings. stoploss and takeprofit are multiplied by Point() when
//--- the orders are sent; slippage is passed to SetDeviationInPoints().
input group "Trade Params";
input double stoploss = 350;
input double takeprofit = 450;
input uint slippage = 100;

//--- Handles of the five indicators created in OnInit(); buffer 0 of each one
//--- supplies one feature column.
int indicator_handle[5];
matrix dataset(bars, 6); //5 x variables (one per indicator) + 1 target variable in the last column

//--- Set to true after TrainTest() has been called so training happens once
bool train_once =false;
//--- Bar count remembered by isnewBar(); 0 means it has not been seeded yet
int prev_bars = 0;
//--- Most recent tick, filled by SymbolInfoTick() in OnTick()
MqlTick ticks;
//+------------------------------------------------------------------+
//| Expert initialization function                                   |
//|                                                                  |
//| Validates the inputs, creates the five indicator handles whose   |
//| values are the model features and configures the trade object.   |
//|                                                                  |
//| Returns INIT_SUCCEEDED on success, INIT_PARAMETERS_INCORRECT     |
//| when `bars` is not positive, and INIT_FAILED when any indicator  |
//| handle could not be created.                                     |
//+------------------------------------------------------------------+
int OnInit()
  {

//--- The dataset and the target vector are sized by `bars`, so it must be positive

    if (bars <= 0)
      {
        PrintFormat("%s: input 'bars' must be greater than zero, got %d", __FUNCTION__, bars);
        return(INIT_PARAMETERS_INCORRECT);
      }

//--- Trend following indicators

    indicator_handle[0] = iAMA(Symbol(), PERIOD_CURRENT, 9 , 2 , 30, 0, PRICE_OPEN);
    indicator_handle[1] = iADX(Symbol(), PERIOD_CURRENT, 14);
    indicator_handle[2] = iADXWilder(Symbol(), PERIOD_CURRENT, 14);
    indicator_handle[3] = iBands(Symbol(), PERIOD_CURRENT, 20, 0, 2.0, PRICE_OPEN);
    indicator_handle[4] = iDEMA(Symbol(), PERIOD_CURRENT, 14, 0, PRICE_OPEN);

//--- Every later buffer copy depends on these handles, so refuse to start
//--- when one of them could not be created

    for (int i=0; i<ArraySize(indicator_handle); i++)
      if (indicator_handle[i] == INVALID_HANDLE)
        {
          PrintFormat("%s: failed to create indicator handle #%d, error = %d", __FUNCTION__, i, GetLastError());
          return(INIT_FAILED);
        }

//---

   TesterHideIndicators(true);

   train_once =false; //force a fresh training run after every (re)initialization

   m_trade.SetExpertMagicNumber(MAGICNUMBER);
   m_trade.SetDeviationInPoints(slippage);
   m_trade.SetMarginMode();
   m_trade.SetTypeFillingBySymbol(Symbol());

   ChartRedraw();

   return(INIT_SUCCEEDED);
  }
//+------------------------------------------------------------------+
//| Expert deinitialization function                                 |
//|                                                                  |
//| Frees the model objects created during training and releases the |
//| indicator handles created in OnInit().                           |
//| `reason` is the deinitialization reason code; it is not used.    |
//+------------------------------------------------------------------+
void OnDeinit(const int reason)
  {
//--- Delete only objects that were actually created, then reset the pointers
//--- so a later re-initialization never sees a stale value

   if (CheckPointer(pca)==POINTER_DYNAMIC)
     delete pca;
   pca = NULL;

   if (CheckPointer(lda)==POINTER_DYNAMIC)
     delete lda;
   lda = NULL;

   if (CheckPointer(classifier_tree)==POINTER_DYNAMIC)
     delete classifier_tree;
   classifier_tree = NULL;

//--- Release the indicator handles so they do not outlive the expert

   for (int i=0; i<ArraySize(indicator_handle); i++)
     {
       if (indicator_handle[i] != INVALID_HANDLE)
         IndicatorRelease(indicator_handle[i]);

       indicator_handle[i] = INVALID_HANDLE;
     }
  }
//+------------------------------------------------------------------+
//| Expert tick function                                             |
//|                                                                  |
//| Trains the model on the first tick. After that, once per new bar |
//| it reads the current value of every indicator, reduces the       |
//| values with the selected technique, asks the decision tree for a |
//| class and opens a minimum-lot trade in that direction unless a   |
//| position of the same type already exists.                        |
//|                                                                  |
//| Class 1 (bullish) opens a buy. Class 0 - the non-bullish label   |
//| produced by TrainTest() - opens a sell; -1 is also accepted as a |
//| sell signal for compatibility with the previous check.           |
//+------------------------------------------------------------------+
void OnTick()
  {
//---

   if (!train_once) //call the function to train the model once on the program lifetime
    {
      TrainTest();
      train_once = true; 
    }

//--- We want to trade on the bar opening only, so leave before doing any
//--- indicator or model work on every other tick

   if (!isnewBar(PERIOD_CURRENT))
     return;

//--- The model objects are created during training; without them there is
//--- nothing to predict with

   if (CheckPointer(classifier_tree)==POINTER_INVALID)
     return;
   if (dimension_reduction==LDA && CheckPointer(lda)==POINTER_INVALID)
     return;
   if (dimension_reduction==PCA && CheckPointer(pca)==POINTER_INVALID)
     return;

//--- Collect the current value of every indicator

   vector inputs(indicator_handle.Size());
   vector buffer;

   for (uint i=0; i<indicator_handle.Size(); i++)
     {
       //a failed copy would leave the buffer unusable, so skip this bar instead of reading buffer[0]
       if (!buffer.CopyIndicatorBuffer(indicator_handle[i], 0, 0, 1) || buffer.Size() < 1)
         {
           PrintFormat("%s: failed to read indicator #%u, error = %d", __FUNCTION__, i, GetLastError());
           return;
         }

       inputs[i] = buffer[0]; //add its value to the inputs vector
     }

//--- Transform the new data to fit the dimensions selected during training

   vector transformed_inputs = {};
   switch(dimension_reduction)
     {
      case LDA:
         transformed_inputs = lda.transform(inputs);
        break;
      case PCA:
         transformed_inputs = pca.transform(inputs);
        break;
     }

//---

   int signal = (int)classifier_tree.predict(transformed_inputs);

   if (!SymbolInfoTick(Symbol(), ticks))
     {
       PrintFormat("%s: failed to get the latest tick, error = %d", __FUNCTION__, GetLastError());
       return;
     }

   double min_lot     = SymbolInfoDouble(Symbol(), SYMBOL_VOLUME_MIN);
   double sl_distance = stoploss*Point();
   double tp_distance = takeprofit*Point();
   int    digits      = Digits();

   if (signal == 0 || signal == -1) //non-bullish class: sell
     {
       if (!PosExists(MAGICNUMBER, POSITION_TYPE_SELL)) // If a sell trade doesnt exist
         {
           double sl = NormalizeDouble(ticks.bid + sl_distance, digits);
           double tp = NormalizeDouble(ticks.bid - tp_distance, digits);

           if (!m_trade.Sell(min_lot, Symbol(), ticks.bid, sl, tp))
             PrintFormat("%s: sell request failed, retcode = %u (%s)", __FUNCTION__, m_trade.ResultRetcode(), m_trade.ResultRetcodeDescription());
         }
     }
   else
     {
       if (!PosExists(MAGICNUMBER, POSITION_TYPE_BUY))  // If a buy trade doesnt exist
         {
           double sl = NormalizeDouble(ticks.ask - sl_distance, digits);
           double tp = NormalizeDouble(ticks.ask + tp_distance, digits);

           if (!m_trade.Buy(min_lot, Symbol(), ticks.ask, sl, tp))
             PrintFormat("%s: buy request failed, retcode = %u (%s)", __FUNCTION__, m_trade.ResultRetcode(), m_trade.ResultRetcodeDescription());
         }
     }
  }
//+------------------------------------------------------------------+
//| Builds the dataset, fits the selected dimension reducer and the  |
//| decision tree, and prints train/test accuracy.                   |
//|                                                                  |
//| Columns 0..4 of `dataset` receive buffer 0 of each indicator for |
//| the last `bars` bars; the last column is the target: 1 when the  |
//| bar closed above its open, 0 otherwise.                          |
//|                                                                  |
//| When the indicator values or the price history cannot be copied  |
//| in full, the function logs the problem and returns without       |
//| creating any model object.                                       |
//+------------------------------------------------------------------+
void TrainTest()
 {
   vector buffer = {};
   for (int i=0; i<ArraySize(indicator_handle); i++)
    {
      //a short or failed copy cannot fill a dataset column, so stop instead of training on bad data
      if (!buffer.CopyIndicatorBuffer(indicator_handle[i], 0, 0, bars) || buffer.Size() != (ulong)bars)
        {
          PrintFormat("%s: failed to copy %d values of indicator #%d, error = %d", __FUNCTION__, bars, i, GetLastError());
          return;
        }

      dataset.Col(buffer, i); //add the indicator buffer values to the dataset matrix 
    }

//---

   MqlRates rates[];
   if (CopyRates(Symbol(), PERIOD_CURRENT,0,bars, rates) != bars)
     {
       PrintFormat("%s: failed to copy %d bars of price history, error = %d", __FUNCTION__, bars, GetLastError());
       return;
     }

//--- Creating the target variable: every row gets an explicit label
//--- NOTE(review): row i holds the indicator values and the candle direction
//--- of the same bar - confirm this is the intended feature/target alignment

   vector y(bars);
   for (int i=0; i<bars; i++)
     y[i] = (rates[i].close > rates[i].open) ? 1.0 : 0.0; //bullish candle = class 1, anything else = class 0

//---

   dataset.Col(y, dataset.Cols()-1); //add the y variable to the last column

//---

   matrix x_train, x_test;
   vector y_train, y_test;

   MatrixExtend::TrainTestSplitMatrices(dataset,x_train,y_train,x_test,y_test,training_sample_size,random_state_);

   matrix x_transformed = {};
   switch(dimension_reduction)
     {
      case  LDA:

         lda = new CLDA(NULL);

         x_transformed = lda.fit_transform(x_train, y_train); //Transform the training data   

        break;
      case PCA:

         pca = new CPCA(NULL);

         x_transformed = pca.fit_transform(x_train);

        break;
     }


   classifier_tree = new CDecisionTreeClassifier();
   classifier_tree.fit(x_transformed, y_train); //Train the model using the transformed data

   vector preds = classifier_tree.predict(x_transformed); //Make predictions using the transformed data

   Print("Train accuracy: ",Metrics::confusion_matrix(y_train, preds).accuracy);

//--- Apply the already fitted reducer to the held-out rows

   switch(dimension_reduction)
     {
      case  LDA:

        x_transformed = lda.transform(x_test); //Transform the testing data   

        break;

      case PCA:

        x_transformed = pca.transform(x_test); 

        break;
     }
   preds = classifier_tree.predict(x_transformed);

   Print("Test accuracy: ",Metrics::confusion_matrix(y_test, preds).accuracy);

 }
//+------------------------------------------------------------------+
//| Tells whether a position on the current symbol with the given    |
//| magic number and position type is open.                          |
//|                                                                  |
//| magic - magic number the position must carry                     |
//| type  - position type to look for (buy or sell)                  |
//|                                                                  |
//| Returns true as soon as a matching position is found, otherwise  |
//| false.                                                           |
//+------------------------------------------------------------------+

bool PosExists(int magic, ENUM_POSITION_TYPE type)
 {
    int idx = PositionsTotal();

    while (--idx >= 0) //walk the open positions from the last index down to 0
      {
        if (!m_position.SelectByIndex(idx))
           continue; //could not select this position, try the next one

        if (m_position.Symbol() != Symbol())
           continue;

        if (m_position.Magic() != magic)
           continue;

        if (m_position.PositionType() != type)
           continue;

        return (true);
      }

    return (false);
 }
//+------------------------------------------------------------------+
//| Reports whether the bar count of the current symbol on timeframe |
//| TF differs from the count remembered in the global prev_bars.    |
//|                                                                  |
//| While prev_bars is 0 it is first seeded with the current count,  |
//| so that call returns false. When a different count is seen,      |
//| prev_bars is updated and true is returned.                       |
//+------------------------------------------------------------------+
bool isnewBar(ENUM_TIMEFRAMES TF)
 {
   //seed the remembered count on first use
   if (prev_bars == 0)
      prev_bars = Bars(Symbol(), TF);

   int bars_now = Bars(Symbol(), TF);

   if (bars_now == prev_bars)
      return false; //same bar as on the previous call

   prev_bars = bars_now;
   return true;
 }
//+------------------------------------------------------------------+
//|                                                                  |
//+------------------------------------------------------------------+
